% author: ziyan zhu (zzhu1@g.harvard.edu) 
% Extract the intrinsic twist angle theta_I from the local misfit-energy
% distribution (i.e. from the average separation of the AAA stacking spots).
% Fresh workspace and figure, then LaTeX for all text and a large default font.
clear all
clf
f_size = 22;
set(groot, ...
    'DefaultTextInterpreter', 'Latex', ...
    'DefaultLegendInterpreter', 'Latex', ...
    'DefaultAxesTickLabelInterpreter', 'Latex', ...
    'DefaultAxesFontSize', f_size)

% NOTE(review): this variable shadows MATLAB's built-in SYSTEM function and is
% not referenced below -- consider removing it.
system = 'triG';

% Two side-by-side panels, sized for a wide figure window.
tiledlayout(1, 2)
set(gcf, 'Position', [110 364 1082 433])


for run_idx = 1:2
    % Run 1: keep theta_TM (q12) fixed at 1.5 deg, vary theta_MB (q23) and
    %        plot theta_I against the moire-of-moire length.
    % Run 2: keep theta_MB - theta_TM fixed at 0.19 deg and plot theta_I
    %        against theta_TM.
    if run_idx == 1
        % Data also exist for q23 = 1.6 and 1.62 deg, but those twist angles
        % are too small: a higher resolution would be needed to resolve AAA.
        q23_list = [1.65, 1.67, 1.69, 1.71, 1.73, 1.75, 1.77, 1.8, 1.9, 2.0];
        q12_list = 1.5 * ones(size(q23_list));
    else
        q12_list = [1.3, 1.48, 1.49, 1.5, 1.51, 1.52, 1.6, 1.7, 1.8, 1.9, 2.1, 2.5];
        q23_list = q12_list + 0.19;
    end

    % Lattice constant sqrt(3)*1.43 (presumably Angstrom: lengths are divided
    % by 10 below to report nm). The columns of A0 are two vectors of length
    % a0 at 60 degrees (triangular lattice).
    a0 = sqrt(3)*1.43;
    A0 = a0*[1 1/2;
             0 sqrt(3)/2];

    % Per-run accumulators, reset every run so entries from the previous run
    % can never leak into this run's plot.
    nq = numel(q23_list);
    moire = zeros(1, nq);
    bins = cell(1, nq);
    centers = cell(1, nq);
    err = zeros(1, nq);
    theta1 = NaN(1, nq);
    theta2 = NaN(1, nq);

    for q_idx = 1:nq
        fprintf("Loading data %d/%d \n", q_idx, nq)
        % Branch decisions below use the exact degree literal; a
        % deg2rad/rad2deg round trip is not guaranteed to reproduce it.
        q23_deg = q23_list(q_idx);
        q12 = deg2rad(q12_list(q_idx));
        q23 = deg2rad(q23_deg);
        disp([q12, q23]);

        % moire of moire length
        Am_dom = calc_moire_dom_tri(q12, q23, A0);
        moire(q_idx) = norm(Am_dom(:,1))/10;
        disp(['The moire of moire length in is ' num2str(moire(q_idx)) ' nm' ])

        fname = ['local_twist_q12_' num2str(rad2deg(q12)) '_q23_' num2str(rad2deg(q23)) '.mat'];

        % Reuse cached local twist angles if present; otherwise compute them
        % from the displacement/energy output of example.jl.
        if exist(['./data/' fname], 'file')
            S = load(['./data/' fname], 'thetas');
            thetas = S.thetas;
            disp('loading existing data...')
        else
            disp('calculating the local twist angles...')
            masterdir = '/data/';
            N = 54;
            scale = 20; % adjust accordingly
            energy_data = importdata(['.' masterdir 'triG_q12_' num2str(rad2deg(q12)) ...
                'deg_q23_' num2str(rad2deg(q23))...
                'deg_N_' num2str(N) '_scale_' num2str(scale) '_disp_energy.txt']);

            xarr = energy_data.data(:, 1);
            yarr = energy_data.data(:, 2);
            energy = energy_data.data(:, 3:4);

            % Rebuild the regular (4n)-by-(5n) sampling grid.
            % NOTE(review): assumes the file rows are ordered with x varying
            % fastest -- confirm against example.jl.
            n = 500;
            X = linspace(min(xarr), max(xarr), 4*n);
            Y = linspace(min(yarr), max(yarr), 5*n);
            [xtmp, ytmp] = ndgrid(X, Y);

            misfit12 = reshape(energy(:,1), size(xtmp))*1e3;
            misfit23 = reshape(energy(:,2), size(ytmp))*1e3;

            % Downsampling stride and cluster-radius factor depend on the
            % twist angle (finer sampling for small angles).
            if q23_deg < 1.62
                step = 1;
                fac = 10;
            else
                step = 3;
                fac = 5;
            end
            % Candidate threshold for the total misfit, as a fraction of its max.
            if q23_deg > 1.62 && q23_deg <= 1.69
                total_frac = 0.4;
            else
                total_frac = 0.35;
            end

            misfit12_down = misfit12(1:step:end, 1:step:end);
            misfit23_down = misfit23(1:step:end, 1:step:end);
            misfit_down = misfit12_down + misfit23_down;
            xarr_down = xtmp(1:step:end, 1:step:end);
            yarr_down = ytmp(1:step:end, 1:step:end);
            xy = [xarr_down(:)'; yarr_down(:)'];   % 2-by-M list of grid points

            % Diagonal of one downsampled grid cell; clusters have radius fac*r.
            xint = mean(diff(xarr_down(:,1)));
            yint = mean(diff(yarr_down(1,:)));
            r = sqrt(xint^2 + yint^2);
            radius = fac*r;

            % Candidate points: grid points above a fraction of each map's max.
            sel12 = misfit12_down > max(misfit12_down(:))*0.8;
            sel23 = misfit23_down > max(misfit23_down(:))*0.92;
            sel = misfit_down > max(misfit_down(:))*total_frac;

            % Collapse candidates to one maximum per cluster. The misfit12/23
            % peaks are not used below (left in the workspace for inspection);
            % only the total-misfit peaks feed the triangulation.
            % NOTE(review): the misfit23 and total passes skip the first two
            % candidates -- confirm this is intended.
            [x_max12, y_max12] = cluster_maxima(xarr_down(sel12), yarr_down(sel12), ...
                1, xy, misfit12_down(:), radius);
            [x_max23, y_max23] = cluster_maxima(xarr_down(sel23), yarr_down(sel23), ...
                3, xy, misfit23_down(:), radius);
            [x_max, y_max] = cluster_maxima(xarr_down(sel), yarr_down(sel), ...
                3, xy, misfit_down(:), radius);

            % Keep each peak that has peaks up and to the right (offset > 2r in
            % both x and y) with the nearest one closer than r_cut; the kept
            % peaks are the vertices of the AAA triangulation.
            r_cut = 1000;
            x_list = [];
            y_list = [];
            for i = 1:numel(x_max)
                up_right = (x_max > x_max(i) + 2*r) & (y_max > y_max(i) + 2*r);
                % NOTE(review): requires at least two such neighbours (> 1);
                % confirm this was not meant to be >= 1.
                if nnz(up_right) > 1
                    vec = [x_max(up_right); y_max(up_right)] - [x_max(i); y_max(i)];
                    dis = sqrt(vec(1,:).^2 + vec(2,:).^2);
                    % discard separations <= 50 (presumably duplicate peaks)
                    r_min = min(dis(dis > 50));
                    if r_min < r_cut
                        x_list(end+1) = x_max(i); %#ok<AGROW>
                        y_list(end+1) = y_max(i); %#ok<AGROW>
                    end
                end
            end
            if numel(x_list) < 3
                error('Only %d AAA peaks found for q12 = %g, q23 = %g deg; cannot triangulate.', ...
                    numel(x_list), rad2deg(q12), q23_deg);
            end

            % Each Delaunay triangle of AAA sites gives a local wavelength
            % (equilateral triangle: A = sqrt(3)/4*lambda^2) and a local twist
            % angle theta = a0/lambda (small-angle form), reported in degrees.
            DT = delaunay(x_list, y_list);
            [Xt, Yt, A] = triarea(DT, [x_list; y_list]');
            xcoord = mean(Xt, 2);   % triangle centroids
            ycoord = mean(Yt, 2);
            lambda = sqrt(4*A/sqrt(3));
            thetas = rad2deg(a0./lambda);
            T = delaunay(xcoord, ycoord);

            save(['./data/' fname], 'thetas', 'lambda', 'T', 'xcoord', 'ycoord');
            amp_fname = ['q12_' num2str(rad2deg(q12)) '_q23_' num2str(rad2deg(q23)) 'energy_gsfex10.mat'];
            save(['./data_amp/' amp_fname], 'xtmp', 'ytmp', 'misfit12', 'misfit23');
        end

        % Histogram of the local twist angles below 3 degrees.
        theta = thetas(thetas < 3);
        if q23_deg == 1.65 || q23_deg == 2.0
            nbins = 80;
        else
            nbins = 60;
        end
        [bins{q_idx}, centers{q_idx}] = hist(theta, nbins);
    end

    % From each normalized histogram take the local maxima above a per-case
    % height: theta1 is the highest peak, theta2 (plotted as theta_I) the
    % lowest peak to its left. Both stay NaN when no peak qualifies.
    for q_idx = 1:nq
        q23_deg = q23_list(q_idx);
        xx = centers{q_idx};
        bins_norm = bins{q_idx}/max(bins{q_idx});
        err(q_idx) = mean(diff(xx));   % histogram bin width, used as error bar

        if q23_deg == 1.9
            min_height = 0.4;
        elseif q23_deg == 2.0
            min_height = 0.96;
        elseif q23_deg == 1.6
            min_height = 0.05;
        else
            min_height = 0.09;
        end
        is_peak = islocalmax(bins_norm);
        xx = xx(is_peak);
        yy = bins_norm(is_peak);
        keep = yy > min_height;
        xx = xx(keep);
        yy = yy(keep);

        if isempty(xx)
            continue
        end
        [~, a] = max(yy);
        theta1(q_idx) = xx(a);
        left = find(xx < xx(a));
        if ~isempty(left)
            [~, k] = min(yy(left));
            b = left(k);
            if abs(xx(a) - xx(b)) > 1e-3
                theta2(q_idx) = xx(b);
            end
        end
    end

    nexttile
    hold on
    if run_idx == 1
        errorbar(moire, theta2, err, 'o', 'Linewidth', 1.5);
        box on
        xlabel('$\Lambda$ (nm)');
        ylabel('$\theta_I (^\circ)$');
        xlim([min(moire)-10 max(moire)+10])
    else
        errorbar(q12_list, theta2, err, 'o', 'Linewidth', 1.5);
        theta_list = linspace(min(q12_list), max(q12_list));
        plot(theta_list, theta_list, 'k--', 'LineWidth', 1.5);
        box on
        xlabel('$\theta_{TM}$');
    end
    fig = gcf;
end

function [xpk, ypk] = cluster_maxima(xc, yc, first_idx, xy, vals, radius)
% CLUSTER_MAXIMA  Reduce candidate points to one maximum per cluster.
%   [xpk, ypk] = CLUSTER_MAXIMA(xc, yc, first_idx, xy, vals, radius) visits
%   the candidates (xc(k), yc(k)) for k = first_idx:numel(xc) in order. The
%   first visited candidate, and any later one farther than RADIUS from every
%   peak found so far, seeds a new cluster. The cluster's peak is the point of
%   XY (2-by-M) within RADIUS of the seed with the largest VALS (M-by-1) value.
%   Returns the peak coordinates as row vectors ([] when there are none).
    xpk = [];
    ypk = [];
    for k = first_idx:numel(xc)
        seed = [xc(k); yc(k)];
        if ~isempty(xpk)
            d = seed - [xpk; ypk];
            if ~(min(sqrt(d(1,:).^2 + d(2,:).^2)) > radius)
                continue   % belongs to a cluster already found
            end
        end
        d = seed - xy;
        near = sqrt(d(1,:).^2 + d(2,:).^2) < radius;
        xs = xy(1, near);
        ys = xy(2, near);
        [~, m] = max(vals(near));
        xpk(end+1) = xs(m); %#ok<AGROW>
        ypk(end+1) = ys(m); %#ok<AGROW>
    end
end
    
function [Xt, Yt, A] = triarea(t, p)
% TRIAREA  Vertex coordinates and unsigned areas of triangles in a triangulation.
%   [Xt, Yt, A] = TRIAREA(t, p)
%   t      : m-by-3 matrix of row indices into p (e.g. output of DELAUNAY)
%   p      : point coordinates, x in column 1 and y in column 2
%   Xt, Yt : m-by-3 x / y coordinates of each triangle's three vertices
%   A      : m-by-1 triangle areas, |cross product of two edges| / 2
    xs = p(:, 1);
    ys = p(:, 2);
    % reshape keeps the m-by-3 layout even when t is a single row
    Xt = reshape(xs(t), size(t));
    Yt = reshape(ys(t), size(t));
    % edge vectors u = v2 - v1 and w = v3 - v1
    ux = Xt(:, 2) - Xt(:, 1);
    uy = Yt(:, 2) - Yt(:, 1);
    wx = Xt(:, 3) - Xt(:, 1);
    wy = Yt(:, 3) - Yt(:, 1);
    A = abs(ux .* wy - wx .* uy) / 2;
end
